{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 导入第三方包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:52:16.629173Z",
     "start_time": "2021-03-15T00:52:16.621194Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import gc\n",
    "import math\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import lightgbm as lgb\n",
    "import xgboost as xgb\n",
    "from catboost import CatBoostRegressor\n",
    "from sklearn.linear_model import SGDRegressor, LinearRegression, Ridge\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "\n",
    "\n",
    "from sklearn.model_selection import StratifiedKFold, KFold\n",
    "from sklearn.metrics import log_loss\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:52:22.085956Z",
     "start_time": "2021-03-15T00:52:19.571864Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>heartbeat_signals</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.9912297987616655,0.9435330436439665,0.764677...</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.9714822034884503,0.9289687459588268,0.572932...</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>1.0,0.9591487564065292,0.7013782792997189,0.23...</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.9757952826275774,0.9340884687738161,0.659636...</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.0,0.055816398940721094,0.26129357194994196,0...</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id                                  heartbeat_signals  label\n",
       "0   0  0.9912297987616655,0.9435330436439665,0.764677...    0.0\n",
       "1   1  0.9714822034884503,0.9289687459588268,0.572932...    0.0\n",
       "2   2  1.0,0.9591487564065292,0.7013782792997189,0.23...    2.0\n",
       "3   3  0.9757952826275774,0.9340884687738161,0.659636...    0.0\n",
       "4   4  0.0,0.055816398940721094,0.26129357194994196,0...    2.0"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the raw competition files; heartbeat_signals is still one comma-separated string per row.\n",
    "train, test = pd.read_csv('train.csv'), pd.read_csv('testA.csv')\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:52:41.773931Z",
     "start_time": "2021-03-15T00:52:41.760966Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>heartbeat_signals</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>100000</td>\n",
       "      <td>0.9915713654170097,1.0,0.6318163407681274,0.13...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>100001</td>\n",
       "      <td>0.6075533139615096,0.5417083883163654,0.340694...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>100002</td>\n",
       "      <td>0.9752726292239277,0.6710965234906665,0.686758...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>100003</td>\n",
       "      <td>0.9956348033996116,0.9170249621481004,0.521096...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>100004</td>\n",
       "      <td>1.0,0.8879490481178918,0.745564725322326,0.531...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       id                                  heartbeat_signals\n",
       "0  100000  0.9915713654170097,1.0,0.6318163407681274,0.13...\n",
       "1  100001  0.6075533139615096,0.5417083883163654,0.340694...\n",
       "2  100002  0.9752726292239277,0.6710965234906665,0.686758...\n",
       "3  100003  0.9956348033996116,0.9170249621481004,0.521096...\n",
       "4  100004  1.0,0.8879490481178918,0.745564725322326,0.531..."
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the raw test file: same layout as train but without a label column.\n",
    "test.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:53:20.837171Z",
     "start_time": "2021-03-15T00:53:20.824203Z"
    }
   },
   "outputs": [],
   "source": [
    "def reduce_mem_usage(df, use_float16=True):\n",
    "    \"\"\"Downcast the columns of ``df`` to the smallest dtype that holds their values.\n",
    "\n",
    "    - Integer columns (signed or unsigned NumPy ints) become the narrowest of\n",
    "      int8/int16/int32/int64 whose range covers [min, max]; otherwise they are left as is.\n",
    "    - Float columns become float16 (only if ``use_float16``) or float32 when the\n",
    "      range fits, else float64. Note float16 keeps only ~3 significant decimal digits.\n",
    "    - Object columns become ``category``.\n",
    "    - Every other dtype (bool, datetime, category, pandas extension dtypes) is left untouched.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Frame to shrink; it is modified in place and also returned.\n",
    "    use_float16 : bool, default True\n",
    "        Allow float16 as a target type; pass False to stop at float32.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        The same ``df`` object with narrowed dtypes.\n",
    "    \"\"\"\n",
    "    start_mem = df.memory_usage().sum() / 1024**2\n",
    "    print('Memory usage of dataframe is {:.2f} MB'.format(start_mem))\n",
    "\n",
    "    int_types = (np.int8, np.int16, np.int32, np.int64)\n",
    "    float_types = (np.float16, np.float32) if use_float16 else (np.float32,)\n",
    "\n",
    "    for col in df.columns:\n",
    "        col_type = df[col].dtype\n",
    "\n",
    "        if col_type == object:\n",
    "            df[col] = df[col].astype('category')\n",
    "            continue\n",
    "        # Extension dtypes (category, nullable Int64, tz-aware datetimes) are not\n",
    "        # np.dtype instances; leave them as they are instead of guessing a NumPy type.\n",
    "        if not isinstance(col_type, np.dtype):\n",
    "            continue\n",
    "\n",
    "        if np.issubdtype(col_type, np.integer):\n",
    "            c_min, c_max = df[col].min(), df[col].max()\n",
    "            for candidate in int_types:\n",
    "                info = np.iinfo(candidate)\n",
    "                # Inclusive bounds: a column spanning exactly [-128, 127] fits int8.\n",
    "                if c_min >= info.min and c_max <= info.max:\n",
    "                    df[col] = df[col].astype(candidate)\n",
    "                    break\n",
    "        elif np.issubdtype(col_type, np.floating):\n",
    "            c_min, c_max = df[col].min(), df[col].max()\n",
    "            for candidate in float_types:\n",
    "                info = np.finfo(candidate)\n",
    "                if c_min >= info.min and c_max <= info.max:\n",
    "                    df[col] = df[col].astype(candidate)\n",
    "                    break\n",
    "            else:\n",
    "                # Range too wide, or an all-NaN column (NaN comparisons are False).\n",
    "                df[col] = df[col].astype(np.float64)\n",
    "\n",
    "    end_mem = df.memory_usage().sum() / 1024**2\n",
    "    print('Memory usage after optimization is: {:.2f} MB'.format(end_mem))\n",
    "    if start_mem > 0:\n",
    "        print('Decreased by {:.1f}%'.format(100 * (start_mem - end_mem) / start_mem))\n",
    "\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:54:18.244721Z",
     "start_time": "2021-03-15T00:53:59.807775Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Memory usage of dataframe is 157.93 MB\n",
      "Memory usage after optimization is: 39.67 MB\n",
      "Decreased by 74.9%\n",
      "Memory usage of dataframe is 31.43 MB\n",
      "Memory usage after optimization is: 7.90 MB\n",
      "Decreased by 74.9%\n"
     ]
    }
   ],
   "source": [
    "# Simple preprocessing: expand the comma-separated heartbeat_signals string into one\n",
    "# float column per value (s_0, s_1, ...). As before, the expanded frame is all float64\n",
    "# (id and label included) until reduce_mem_usage narrows it.\n",
    "def expand_signals(raw, label_col=None):\n",
    "    \"\"\"Return a wide frame with columns ['id', 's_0', ..., 's_{k-1}'] (+ ``label_col``).\n",
    "\n",
    "    ``raw`` needs an 'id' column and a 'heartbeat_signals' column of comma-separated\n",
    "    numbers; every row must contain the same number k of values. Each row is parsed by\n",
    "    NumPy in one call instead of a Python float() per value.\n",
    "    \"\"\"\n",
    "    signals = np.array([np.array(s.split(','), dtype=np.float64) for s in raw['heartbeat_signals']])\n",
    "    wide = pd.DataFrame(signals, columns=['s_' + str(i) for i in range(signals.shape[1])])\n",
    "    wide.insert(0, 'id', raw['id'].to_numpy(dtype=np.float64))\n",
    "    if label_col is not None:\n",
    "        wide[label_col] = raw[label_col].to_numpy(dtype=np.float64)\n",
    "    return wide\n",
    "\n",
    "\n",
    "# Only expand frames that are still raw, so re-running this cell is a no-op instead of\n",
    "# failing on the already-expanded train/test (which no longer have heartbeat_signals).\n",
    "if 'heartbeat_signals' in train.columns:\n",
    "    train = reduce_mem_usage(expand_signals(train, label_col='label'))\n",
    "if 'heartbeat_signals' in test.columns:\n",
    "    test = reduce_mem_usage(expand_signals(test))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:54:57.351196Z",
     "start_time": "2021-03-15T00:54:57.321310Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>s_0</th>\n",
       "      <th>s_1</th>\n",
       "      <th>s_2</th>\n",
       "      <th>s_3</th>\n",
       "      <th>s_4</th>\n",
       "      <th>s_5</th>\n",
       "      <th>s_6</th>\n",
       "      <th>s_7</th>\n",
       "      <th>s_8</th>\n",
       "      <th>...</th>\n",
       "      <th>s_196</th>\n",
       "      <th>s_197</th>\n",
       "      <th>s_198</th>\n",
       "      <th>s_199</th>\n",
       "      <th>s_200</th>\n",
       "      <th>s_201</th>\n",
       "      <th>s_202</th>\n",
       "      <th>s_203</th>\n",
       "      <th>s_204</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.991211</td>\n",
       "      <td>0.943359</td>\n",
       "      <td>0.764648</td>\n",
       "      <td>0.618652</td>\n",
       "      <td>0.379639</td>\n",
       "      <td>0.190796</td>\n",
       "      <td>0.040222</td>\n",
       "      <td>0.026001</td>\n",
       "      <td>0.031708</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>0.971680</td>\n",
       "      <td>0.929199</td>\n",
       "      <td>0.572754</td>\n",
       "      <td>0.178467</td>\n",
       "      <td>0.122986</td>\n",
       "      <td>0.132324</td>\n",
       "      <td>0.094421</td>\n",
       "      <td>0.089600</td>\n",
       "      <td>0.030487</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2.0</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.958984</td>\n",
       "      <td>0.701172</td>\n",
       "      <td>0.231812</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.080688</td>\n",
       "      <td>0.128418</td>\n",
       "      <td>0.187500</td>\n",
       "      <td>0.280762</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.0</td>\n",
       "      <td>0.975586</td>\n",
       "      <td>0.934082</td>\n",
       "      <td>0.659668</td>\n",
       "      <td>0.249878</td>\n",
       "      <td>0.237061</td>\n",
       "      <td>0.281494</td>\n",
       "      <td>0.249878</td>\n",
       "      <td>0.249878</td>\n",
       "      <td>0.241455</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.055817</td>\n",
       "      <td>0.261230</td>\n",
       "      <td>0.359863</td>\n",
       "      <td>0.433105</td>\n",
       "      <td>0.453613</td>\n",
       "      <td>0.499023</td>\n",
       "      <td>0.542969</td>\n",
       "      <td>0.616699</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 207 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    id       s_0       s_1       s_2       s_3       s_4       s_5       s_6  \\\n",
       "0  0.0  0.991211  0.943359  0.764648  0.618652  0.379639  0.190796  0.040222   \n",
       "1  1.0  0.971680  0.929199  0.572754  0.178467  0.122986  0.132324  0.094421   \n",
       "2  2.0  1.000000  0.958984  0.701172  0.231812  0.000000  0.080688  0.128418   \n",
       "3  3.0  0.975586  0.934082  0.659668  0.249878  0.237061  0.281494  0.249878   \n",
       "4  4.0  0.000000  0.055817  0.261230  0.359863  0.433105  0.453613  0.499023   \n",
       "\n",
       "        s_7       s_8  ...  s_196  s_197  s_198  s_199  s_200  s_201  s_202  \\\n",
       "0  0.026001  0.031708  ...    0.0    0.0    0.0    0.0    0.0    0.0    0.0   \n",
       "1  0.089600  0.030487  ...    0.0    0.0    0.0    0.0    0.0    0.0    0.0   \n",
       "2  0.187500  0.280762  ...    0.0    0.0    0.0    0.0    0.0    0.0    0.0   \n",
       "3  0.249878  0.241455  ...    0.0    0.0    0.0    0.0    0.0    0.0    0.0   \n",
       "4  0.542969  0.616699  ...    0.0    0.0    0.0    0.0    0.0    0.0    0.0   \n",
       "\n",
       "   s_203  s_204  label  \n",
       "0    0.0    0.0    0.0  \n",
       "1    0.0    0.0    0.0  \n",
       "2    0.0    0.0    2.0  \n",
       "3    0.0    0.0    0.0  \n",
       "4    0.0    0.0    2.0  \n",
       "\n",
       "[5 rows x 207 columns]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wide training frame after expansion: id, the s_* signal columns, then label.\n",
    "train.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:55:36.644040Z",
     "start_time": "2021-03-15T00:55:36.619678Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>s_0</th>\n",
       "      <th>s_1</th>\n",
       "      <th>s_2</th>\n",
       "      <th>s_3</th>\n",
       "      <th>s_4</th>\n",
       "      <th>s_5</th>\n",
       "      <th>s_6</th>\n",
       "      <th>s_7</th>\n",
       "      <th>s_8</th>\n",
       "      <th>...</th>\n",
       "      <th>s_195</th>\n",
       "      <th>s_196</th>\n",
       "      <th>s_197</th>\n",
       "      <th>s_198</th>\n",
       "      <th>s_199</th>\n",
       "      <th>s_200</th>\n",
       "      <th>s_201</th>\n",
       "      <th>s_202</th>\n",
       "      <th>s_203</th>\n",
       "      <th>s_204</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>100000.0</td>\n",
       "      <td>0.991699</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.631836</td>\n",
       "      <td>0.136230</td>\n",
       "      <td>0.041412</td>\n",
       "      <td>0.102722</td>\n",
       "      <td>0.120850</td>\n",
       "      <td>0.123413</td>\n",
       "      <td>0.107910</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>100001.0</td>\n",
       "      <td>0.607422</td>\n",
       "      <td>0.541504</td>\n",
       "      <td>0.340576</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.090698</td>\n",
       "      <td>0.164917</td>\n",
       "      <td>0.195068</td>\n",
       "      <td>0.168823</td>\n",
       "      <td>0.198853</td>\n",
       "      <td>...</td>\n",
       "      <td>0.389893</td>\n",
       "      <td>0.386963</td>\n",
       "      <td>0.367188</td>\n",
       "      <td>0.364014</td>\n",
       "      <td>0.360596</td>\n",
       "      <td>0.357178</td>\n",
       "      <td>0.350586</td>\n",
       "      <td>0.350586</td>\n",
       "      <td>0.350586</td>\n",
       "      <td>0.36377</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>100002.0</td>\n",
       "      <td>0.975098</td>\n",
       "      <td>0.670898</td>\n",
       "      <td>0.686523</td>\n",
       "      <td>0.708496</td>\n",
       "      <td>0.718750</td>\n",
       "      <td>0.716797</td>\n",
       "      <td>0.720703</td>\n",
       "      <td>0.701660</td>\n",
       "      <td>0.596680</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>100003.0</td>\n",
       "      <td>0.995605</td>\n",
       "      <td>0.916992</td>\n",
       "      <td>0.520996</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.221802</td>\n",
       "      <td>0.404053</td>\n",
       "      <td>0.490479</td>\n",
       "      <td>0.527344</td>\n",
       "      <td>0.518066</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>100004.0</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.888184</td>\n",
       "      <td>0.745605</td>\n",
       "      <td>0.531738</td>\n",
       "      <td>0.380371</td>\n",
       "      <td>0.224609</td>\n",
       "      <td>0.091125</td>\n",
       "      <td>0.057648</td>\n",
       "      <td>0.003914</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 206 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         id       s_0       s_1       s_2       s_3       s_4       s_5  \\\n",
       "0  100000.0  0.991699  1.000000  0.631836  0.136230  0.041412  0.102722   \n",
       "1  100001.0  0.607422  0.541504  0.340576  0.000000  0.090698  0.164917   \n",
       "2  100002.0  0.975098  0.670898  0.686523  0.708496  0.718750  0.716797   \n",
       "3  100003.0  0.995605  0.916992  0.520996  0.000000  0.221802  0.404053   \n",
       "4  100004.0  1.000000  0.888184  0.745605  0.531738  0.380371  0.224609   \n",
       "\n",
       "        s_6       s_7       s_8  ...     s_195     s_196     s_197     s_198  \\\n",
       "0  0.120850  0.123413  0.107910  ...  0.000000  0.000000  0.000000  0.000000   \n",
       "1  0.195068  0.168823  0.198853  ...  0.389893  0.386963  0.367188  0.364014   \n",
       "2  0.720703  0.701660  0.596680  ...  0.000000  0.000000  0.000000  0.000000   \n",
       "3  0.490479  0.527344  0.518066  ...  0.000000  0.000000  0.000000  0.000000   \n",
       "4  0.091125  0.057648  0.003914  ...  0.000000  0.000000  0.000000  0.000000   \n",
       "\n",
       "      s_199     s_200     s_201     s_202     s_203    s_204  \n",
       "0  0.000000  0.000000  0.000000  0.000000  0.000000  0.00000  \n",
       "1  0.360596  0.357178  0.350586  0.350586  0.350586  0.36377  \n",
       "2  0.000000  0.000000  0.000000  0.000000  0.000000  0.00000  \n",
       "3  0.000000  0.000000  0.000000  0.000000  0.000000  0.00000  \n",
       "4  0.000000  0.000000  0.000000  0.000000  0.000000  0.00000  \n",
       "\n",
       "[5 rows x 206 columns]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wide test frame after expansion: id plus the same s_* signal columns, no label.\n",
    "test.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练数据/测试数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:56:15.971953Z",
     "start_time": "2021-03-15T00:56:15.876344Z"
    }
   },
   "outputs": [],
   "source": [
    "# Features are the s_* signal columns; id only identifies rows and label is the target.\n",
    "x_train = train.drop(columns=['id', 'label'])\n",
    "y_train = train['label']\n",
    "x_test = test.drop(columns=['id'])\n",
    "# The model trained on x_train's columns is applied to x_test, so they must match exactly.\n",
    "assert list(x_train.columns) == list(x_test.columns), 'train/test feature columns differ'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:57:01.757175Z",
     "start_time": "2021-03-15T00:57:01.750341Z"
    }
   },
   "outputs": [],
   "source": [
    "def abs_sum(y_pre, y_tru):\n",
    "    \"\"\"Sum of absolute element-wise differences between two equally shaped arrays.\n",
    "\n",
    "    In this notebook it scores one-hot validation labels against predicted class\n",
    "    probabilities (shape (n_samples, 4)); the arguments are symmetric, so their order\n",
    "    does not matter. Any shape works, including 1-D; empty input gives 0.0.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "    \"\"\"\n",
    "    y_pre = np.asarray(y_pre, dtype=float)\n",
    "    y_tru = np.asarray(y_tru, dtype=float)\n",
    "    return float(np.abs(y_pre - y_tru).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:57:42.940805Z",
     "start_time": "2021-03-15T00:57:42.928082Z"
    }
   },
   "outputs": [],
   "source": [
    "def cv_model(clf, train_x, train_y, test_x, clf_name):\n",
    "    \"\"\"5-fold cross-validate a 4-class model and return fold-averaged test probabilities.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    clf : module\n",
    "        Model library; only the LightGBM branch (clf_name == \"lgb\", clf = lightgbm) exists.\n",
    "    train_x : pandas.DataFrame\n",
    "        Training features, sliced positionally with ``.iloc``.\n",
    "    train_y : array-like\n",
    "        Class labels 0..3 aligned row-for-row with ``train_x`` (integral floats are fine).\n",
    "    test_x : array-like\n",
    "        Test features with the same columns as ``train_x``.\n",
    "    clf_name : str\n",
    "        Which model branch to run; currently only \"lgb\".\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray of shape (len(test_x), 4)\n",
    "        Mean over the folds of each fold model's predicted class probabilities.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``clf_name`` is not supported.\n",
    "    \"\"\"\n",
    "    if clf_name != \"lgb\":\n",
    "        raise ValueError('unsupported clf_name: {!r}'.format(clf_name))\n",
    "\n",
    "    folds = 5\n",
    "    seed = 2021\n",
    "    n_classes = 4\n",
    "    kf = KFold(n_splits=folds, shuffle=True, random_state=seed)\n",
    "    # Positional label array: indexing by fold positions works whether train_y is a\n",
    "    # Series (with any index) or a plain array.\n",
    "    labels = np.asarray(train_y)\n",
    "    test_pred_sum = np.zeros((test_x.shape[0], n_classes))\n",
    "    cv_scores = []\n",
    "\n",
    "    for i, (train_index, valid_index) in enumerate(kf.split(train_x, labels)):\n",
    "        print('************************************ {} ************************************'.format(str(i+1)))\n",
    "        trn_x, trn_y = train_x.iloc[train_index], labels[train_index]\n",
    "        val_x, val_y = train_x.iloc[valid_index], labels[valid_index]\n",
    "\n",
    "        train_matrix = clf.Dataset(trn_x, label=trn_y)\n",
    "        valid_matrix = clf.Dataset(val_x, label=val_y)\n",
    "\n",
    "        params = {\n",
    "            'boosting_type': 'gbdt',\n",
    "            'objective': 'multiclass',\n",
    "            'num_class': n_classes,\n",
    "            'num_leaves': 2 ** 5,\n",
    "            'feature_fraction': 0.8,\n",
    "            'bagging_fraction': 0.8,\n",
    "            'bagging_freq': 4,\n",
    "            'learning_rate': 0.1,\n",
    "            'seed': seed,\n",
    "            # A single thread setting: with both nthread and n_jobs given, LightGBM\n",
    "            # warned and used n_jobs anyway.\n",
    "            'n_jobs': 24,\n",
    "            'verbose': -1,\n",
    "        }\n",
    "\n",
    "        # Early stopping and periodic logging go through callbacks because the\n",
    "        # verbose_eval / early_stopping_rounds keyword arguments were removed in\n",
    "        # LightGBM 4. LightGBM < 3.3 names the logging callback print_evaluation.\n",
    "        log_cb = getattr(clf, 'log_evaluation', None) or clf.print_evaluation\n",
    "        model = clf.train(params,\n",
    "                          train_set=train_matrix,\n",
    "                          valid_sets=[valid_matrix],\n",
    "                          num_boost_round=2000,\n",
    "                          callbacks=[clf.early_stopping(200), log_cb(100)])\n",
    "        val_pred = model.predict(val_x, num_iteration=model.best_iteration)\n",
    "        test_pred = model.predict(test_x, num_iteration=model.best_iteration)\n",
    "\n",
    "        # One-hot against the fixed class set 0..n_classes-1 so the truth matrix always\n",
    "        # has n_classes columns in label order, even when a fold lacks some class\n",
    "        # (an encoder re-fitted per fold would drop those columns and misalign them).\n",
    "        val_y = np.eye(n_classes)[val_y.astype(int)]\n",
    "        print('预测的概率矩阵为：')\n",
    "        print(test_pred)\n",
    "        test_pred_sum += test_pred\n",
    "        score = abs_sum(val_y, val_pred)\n",
    "        cv_scores.append(score)\n",
    "        print(cv_scores)\n",
    "    print(\"%s_score_list:\" % clf_name, cv_scores)\n",
    "    print(\"%s_score_mean:\" % clf_name, np.mean(cv_scores))\n",
    "    print(\"%s_score_std:\" % clf_name, np.std(cv_scores))\n",
    "\n",
    "    return test_pred_sum / kf.n_splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-03-15T00:58:22.378103Z",
     "start_time": "2021-03-15T00:58:22.373222Z"
    }
   },
   "outputs": [],
   "source": [
    "def lgb_model(x_train, y_train, x_test):\n",
    "    \"\"\"Cross-validate LightGBM via cv_model and return its fold-averaged test probabilities.\"\"\"\n",
    "    return cv_model(lgb, x_train, y_train, x_test, \"lgb\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-03-15T00:53:32.384Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "************************************ 1 ************************************\n",
      "Training until validation scores don't improve for 200 rounds\n",
      "[100]\tvalid_0's multi_logloss: 0.0525735\n",
      "[200]\tvalid_0's multi_logloss: 0.0422444\n",
      "[300]\tvalid_0's multi_logloss: 0.0407076\n",
      "[400]\tvalid_0's multi_logloss: 0.0420398\n",
      "Early stopping, best iteration is:\n",
      "[289]\tvalid_0's multi_logloss: 0.0405457\n",
      "预测的概率矩阵为：\n",
      "[[9.99969791e-01 2.85197261e-05 1.00341946e-06 6.85357631e-07]\n",
      " [7.93287264e-05 7.69060914e-04 9.99151590e-01 2.00810971e-08]\n",
      " [5.75356884e-07 5.04051497e-08 3.15322414e-07 9.99999059e-01]\n",
      " ...\n",
      " [6.79267940e-02 4.30206297e-04 9.31640185e-01 2.81516302e-06]\n",
      " [9.99960477e-01 3.94098074e-05 8.34030725e-08 2.94638661e-08]\n",
      " [9.88705846e-01 2.14081630e-03 6.67418381e-03 2.47915423e-03]]\n",
      "[607.0736049372186]\n",
      "************************************ 2 ************************************\n",
      "[LightGBM] [Warning] num_threads is set with nthread=28, will be overridden by n_jobs=24. Current value: num_threads=24\n",
      "Training until validation scores don't improve for 200 rounds\n",
      "[100]\tvalid_0's multi_logloss: 0.0566626\n",
      "[200]\tvalid_0's multi_logloss: 0.0450852\n"
     ]
    }
   ],
   "source": [
    "# Cross-validated LightGBM: one model per fold, test predictions averaged over folds.\n",
    "lgb_test = lgb_model(x_train=x_train, y_train=y_train, x_test=x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-03-15T00:53:33.065Z"
    }
   },
   "outputs": [],
   "source": [
    "# Sanity-check the fold-averaged predictions: one row per test sample and four class\n",
    "# probabilities per row, which must still sum to 1 after averaging over folds.\n",
    "assert lgb_test.shape == (len(x_test), 4)\n",
    "assert np.allclose(lgb_test.sum(axis=1), 1.0)\n",
    "lgb_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-03-15T00:53:33.810Z"
    }
   },
   "outputs": [],
   "source": [
    "# Wrap the averaged probabilities in a DataFrame; columns 0-3 are copied to\n",
    "# label_0..label_3 of the submission.\n",
    "temp = pd.DataFrame(data=lgb_test)\n",
    "temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2021-03-15T00:53:34.680Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fill the submission template with the per-class probabilities. Column assignment\n",
    "# aligns on the row index, so the template must have exactly one row per prediction\n",
    "# (presumably in the same order as testA.csv - not checked here).\n",
    "result = pd.read_csv('sample_submit.csv')\n",
    "assert len(result) == len(temp), 'sample_submit.csv and predictions differ in length'\n",
    "for k in range(temp.shape[1]):\n",
    "    result['label_' + str(k)] = temp[k]\n",
    "result.to_csv('submit.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
